{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import random\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def imgCrop(img, x_start, x_end, y_start, y_end):\n",
    "    img_crop = img[x_start : x_end, y_start : y_end]\n",
    "    return img_crop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def imgColorShift(img, b_shift, g_shift, r_shift):\n",
    "    B, G, R = cv2.split(img)\n",
    "    if b_shift == 0:\n",
    "        pass\n",
    "    elif b_shift > 0:\n",
    "        lim = 255 - b_shift\n",
    "        B[B > lim] = 255\n",
    "        B[B <= lim] = (b_shift + B[B <= lim]).astype(img.dtype)\n",
    "        \n",
    "    if g_shift == 0:\n",
    "        pass\n",
    "    elif g_shift > 0:\n",
    "        lim = 255 - g_shift\n",
    "        G[G > lim] = 255\n",
    "        G[G <= lim] = (g_shift + G[G <= lim]).astype(img.dtype)  \n",
    "        \n",
    "    if r_shift == 0:\n",
    "        pass\n",
    "    elif r_shift > 0:\n",
    "        lim = 255 - r_shift\n",
    "        R[R > lim] = 255\n",
    "        R[R <= lim] = (r_shift + R[R <= lim]).astype(img.dtype)    \n",
    "    img_merge = cv2.merge((B, G, R))\n",
    "    return img_merge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def imgSimilarityTransform(img, angle, scale, center_x, center_y):\n",
    "    \"\"\"Rotate and scale `img` about the point (center_x, center_y).\n",
    "\n",
    "    `angle` and `scale` are handed unchanged to cv2.getRotationMatrix2D.\n",
    "    The result is rendered on a canvas with the same width and height as\n",
    "    the input image.\n",
    "    \"\"\"\n",
    "    rows, cols = img.shape[0], img.shape[1]\n",
    "    pivot = (center_x, center_y)\n",
    "    affine = cv2.getRotationMatrix2D(pivot, angle, scale)\n",
    "    # warpAffine takes the output size as (width, height).\n",
    "    return cv2.warpAffine(img, affine, (cols, rows))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def imgPerspectiveTransform(img, random_margin):\n",
    "    \"\"\"Apply a random perspective warp to `img`.\n",
    "\n",
    "    Two quadrilaterals are drawn at random near the image corners, one used\n",
    "    as the source and one as the destination, and the image is warped with\n",
    "    the perspective transform that maps the first onto the second. The\n",
    "    output has the same width and height as the input.\n",
    "\n",
    "    Args:\n",
    "        img: image array whose first two axes are (height, width). Both\n",
    "            colour and single-channel images are accepted.\n",
    "        random_margin: non-negative integer. Left/top coordinates are drawn\n",
    "            from [-random_margin, random_margin]; right/bottom coordinates\n",
    "            from [size - random_margin - 1, size - 1].\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `random_margin` is negative.\n",
    "    \"\"\"\n",
    "    if random_margin < 0:\n",
    "        raise ValueError('random_margin must be non-negative')\n",
    "    # Only the first two axes are needed, so do not require a channel axis.\n",
    "    height, width = img.shape[:2]\n",
    "\n",
    "    near_origin = (-random_margin, random_margin)\n",
    "    near_right = (width - random_margin - 1, width - 1)\n",
    "    near_bottom = (height - random_margin - 1, height - 1)\n",
    "\n",
    "    def _random_quad():\n",
    "        # Corner order: top-left, top-right, bottom-right, bottom-left.\n",
    "        # For every corner x is drawn before y, so a seeded `random`\n",
    "        # yields a reproducible sequence of draws.\n",
    "        corner_ranges = [\n",
    "            (near_origin, near_origin),\n",
    "            (near_right, near_origin),\n",
    "            (near_right, near_bottom),\n",
    "            (near_origin, near_bottom),\n",
    "        ]\n",
    "        return np.float32([[random.randint(*x_range), random.randint(*y_range)]\n",
    "                           for x_range, y_range in corner_ranges])\n",
    "\n",
    "    # A perspective transform is determined by 4 point pairs:\n",
    "    # the source points first, then the target points.\n",
    "    src_quad = _random_quad()\n",
    "    dst_quad = _random_quad()\n",
    "    M_warp = cv2.getPerspectiveTransform(src_quad, dst_quad)\n",
    "    return cv2.warpPerspective(img, M_warp, (width, height))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def imgDataAugmentation(img, crop, shift, center, angle, scale, warp_value):\n",
    "    \"\"\"Run the whole augmentation chain on `img` and return the result.\n",
    "\n",
    "    Steps, in order: crop, colour shift, rotate/scale, random perspective warp.\n",
    "\n",
    "    Args:\n",
    "        img: input image.\n",
    "        crop: sequence (x_start, x_end, y_start, y_end) unpacked into imgCrop.\n",
    "        shift: sequence (b_shift, g_shift, r_shift) unpacked into imgColorShift.\n",
    "        center: sequence (center_x, center_y) unpacked into imgSimilarityTransform.\n",
    "        angle, scale: forwarded to imgSimilarityTransform.\n",
    "        warp_value: forwarded to imgPerspectiveTransform as its random margin.\n",
    "    \"\"\"\n",
    "    stage = imgCrop(img, *crop)\n",
    "    stage = imgColorShift(stage, *shift)\n",
    "    stage = imgSimilarityTransform(stage, angle, scale, *center)\n",
    "    return imgPerspectiveTransform(stage, warp_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_path = 'Lenna.png'\n",
    "img_input = cv2.imread(image_path)\n",
    "# cv2.imread reports a missing or unreadable file by returning None\n",
    "# instead of raising, so fail here with a clear message.\n",
    "if img_input is None:\n",
    "    raise IOError('Could not read image: ' + image_path)\n",
    "cv2.imshow('Input Image', img_input)\n",
    "# Arguments: image, crop (x_start, x_end, y_start, y_end), colour shift (b, g, r),\n",
    "# rotation centre (x, y), angle, scale, warp margin.\n",
    "img_output = imgDataAugmentation(img_input, [0,512,0,512], [-20,20,-20], (img_input.shape[1] / 2, img_input.shape[0] / 2), 30, 1, 60)\n",
    "cv2.imshow('Output Image', img_output)\n",
    "# Wait for any key, then always close the windows; leaving them open after\n",
    "# a non-ESC key would strand unresponsive windows once the cell finishes.\n",
    "key = cv2.waitKey(0)\n",
    "cv2.destroyAllWindows()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
